import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.utils import dim_compress, libentry

# Module logger: a child of the "flag_gems" logger, named after this module
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# NOTE: disabled autotune config generator, kept for reference. The kernel
# below currently gets BLOCK_M / BLOCK_N from ``triton.heuristics`` instead.
# def cfggen():
#     block_m = [1, 2, 4]
#     block_n = [128, 1024, 2048, 4096]
#     configs = [
#         triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
#         for m in block_m
#         for n in block_n
#     ]
#     return configs


@libentry()
# @triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.heuristics(runtime.get_heuristic_config("index_add"))
# @triton.autotune(
#     configs=[], generate_configs="index_add", op_affiliation="cluster", row_sign="M", col_sign="N",
#     key=["M", "N"],
# )
@triton.jit
def index_add_kernel(
    inp,
    inp_cont,
    index,
    src,
    M: tl.constexpr,
    N: tl.constexpr,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Store ``inp[r, index[c]] + alpha * src[r, c]`` into ``inp_cont``.

    Addressing treats ``inp`` / ``inp_cont`` as row-major ``(M, inp_len)`` and
    ``src`` as row-major ``(M, N)``, with the indexed dimension innermost.
    ``index`` holds N column positions into ``inp``. Grid axis 0 tiles rows in
    chunks of BLOCK_M; grid axis 1 tiles index positions in chunks of BLOCK_N.

    NOTE(review): results are written with a plain ``tl.store``. If ``index``
    contains duplicates, several programs/lanes write the same address and
    only one contribution survives. This is also true when ``inp`` and
    ``inp_cont`` alias (in-place use).
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # Row ids as a (BLOCK_M, 1) column so they broadcast against the columns.
    rows_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    # Positions into ``index`` / columns of ``src``: shape (BLOCK_N,).
    cols_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    rows_mask = rows_offsets < M
    index_mask = cols_offsets < N
    # Element-wise AND broadcasting (BLOCK_M, 1) with (BLOCK_N,) into a
    # (BLOCK_M, BLOCK_N) tile mask. Use ``&``: Python ``and`` is not an
    # element-wise tensor operator.
    block_mask = rows_mask & index_mask

    cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
    # Flat offset of inp[row, index[col]] in the (M, inp_len) layout.
    # NOTE(review): computed in the offsets' integer width (int32 unless a
    # wider type is promoted in); may overflow for very large tensors.
    inp_off = rows_offsets * inp_len + cur_indices[None, :]
    cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
    # Flat offset of src[row, col] in the (M, N) layout.
    src_off = rows_offsets * N + cols_offsets[None, :]
    cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
    cur_inp += alpha * cur_src

    tl.store(inp_cont + inp_off, cur_inp, mask=block_mask)


def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place ``index_add``: return a copy of ``inp`` with ``alpha * src``
    added along ``dim`` at the positions listed in ``index``.

    Args:
        inp: Tensor providing the base values; it is not modified.
        dim: Dimension to index along; negative values count from the end.
        index: Positions along ``dim`` of ``inp``. ``index.numel()`` must equal
            ``src.size(dim)`` and every entry must satisfy
            ``0 <= index < inp.size(dim)``.
        src: Tensor with the same number of dimensions as ``inp`` and the same
            size in every dimension except ``dim``.
        alpha: Scalar multiplier applied to ``src`` before adding.

    Returns:
        A new contiguous tensor with the shape of ``inp``.

    Raises:
        AssertionError: If the arguments violate the constraints above.

    NOTE(review): ``index_add_kernel`` stores rather than accumulates. If
    ``index`` contains duplicates, only one contribution per target position
    survives, whereas ``torch.index_add`` sums them.
    """
    logger.debug("GEMS INDEX ADD")
    # Validate ``dim`` before anything calls ``size(dim)`` with it.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    # ``all`` is required here: asserting a bare generator always passes.
    assert all(
        inp.size(i) == src.size(i) for i in range(inp.ndim) if i != dim
    ), "src.size(d) == self.size(d) for all dimensions d != dim"
    # Evaluated on the index's own device (no hard-coded "cuda");
    # an empty index passes trivially.
    assert bool(
        ((index >= 0) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"

    N = index.numel()
    if N == 0 or src.numel() == 0:
        # Nothing to add. This also avoids ``src.numel() // N`` dividing by
        # zero and launching an empty grid.
        return inp.clone(memory_format=torch.contiguous_format)
    # Number of rows: product of every dimension other than ``dim``.
    M = src.numel() // N
    inp_len = inp.size(dim)

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    last_dim = inp.ndim - 1
    if dim != last_dim:
        # The kernel addresses the indexed dimension as the innermost one.
        # dim_compress presumably moves ``dim`` last (the inverse permutation
        # below assumes exactly that).
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)
    out = inp.clone()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)

    if dim != last_dim:
        # Undo the permutation: the original axis ``dim`` currently sits last.
        order = list(range(last_dim))
        order.insert(dim, last_dim)
        return out.permute(order).contiguous()
    return out


def index_add_(inp, dim, index, src, alpha=1):
    """In-place ``index_add_``: add ``alpha * src`` into ``inp`` along ``dim``
    at the positions listed in ``index``, then return ``inp``.

    Args:
        inp: Tensor updated in place; any memory layout is accepted.
        dim: Dimension to index along; negative values count from the end.
        index: Positions along ``dim`` of ``inp``. ``index.numel()`` must equal
            ``src.size(dim)`` and every entry must satisfy
            ``0 <= index < inp.size(dim)``.
        src: Tensor with the same number of dimensions as ``inp`` and the same
            size in every dimension except ``dim``.
        alpha: Scalar multiplier applied to ``src`` before adding.

    Returns:
        ``inp`` itself, after the update.

    Raises:
        AssertionError: If the arguments violate the constraints above.

    NOTE(review): the kernel is launched with the same buffer as source and
    destination and uses a plain store. Duplicate entries in ``index`` can
    therefore lose contributions, whereas ``torch.Tensor.index_add_`` sums them.
    """
    logger.debug("GEMS INDEX ADD_")
    # Validate ``dim`` before anything calls ``size(dim)`` with it.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    # ``all`` is required here: asserting a bare generator always passes.
    assert all(
        inp.size(i) == src.size(i) for i in range(inp.ndim) if i != dim
    ), "src.size(d) == self.size(d) for all dimensions d != dim"
    # Evaluated on the index's own device (no hard-coded "cuda");
    # an empty index passes trivially.
    assert bool(
        ((index >= 0) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"

    N = index.numel()
    if N == 0 or src.numel() == 0:
        # Nothing to add. This also avoids ``src.numel() // N`` dividing by zero.
        return inp
    # Number of rows: product of every dimension other than ``dim``.
    M = src.numel() // N
    inp_len = inp.size(dim)

    # Work on a private contiguous copy (a single copy, instead of clone()
    # followed by contiguous()); the result is copied back into ``inp``.
    work = inp.clone(memory_format=torch.contiguous_format)
    index = index.contiguous()
    src = src.contiguous()

    last_dim = inp.ndim - 1
    if dim != last_dim:
        # The kernel addresses the indexed dimension as the innermost one.
        # dim_compress presumably moves ``dim`` last (the inverse permutation
        # below assumes exactly that).
        work = dim_compress(work, dim)
        src = dim_compress(src, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_add_kernel[grid](work, work, index, src, M, N, alpha, inp_len)

    if dim != last_dim:
        # Undo the permutation. copy_ handles the strided view, so no extra
        # .contiguous() copy is needed.
        order = list(range(last_dim))
        order.insert(dim, last_dim)
        work = work.permute(order)
    inp.copy_(work)
    return inp
